import csv
import sys
csv.field_size_limit(sys.maxsize)
import pandas as pd
import math
from dataclasses import dataclass, field
from typing import Optional
import random
import torch
import transformers
from tqdm import tqdm
from torch.multiprocessing import set_start_method
import numpy as np

def print_trainable_parameters(model):
	"""
	Print the number of trainable parameters in the model.

	Sums ``numel()`` over every parameter yielded by ``model.named_parameters()``
	and over the subset with ``requires_grad`` set, then prints both totals and
	the trainable percentage on a single line.

	Args:
		model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``)
			that yields ``(name, parameter)`` pairs.
	"""
	trainable_params = 0
	all_param = 0
	for _, param in model.named_parameters():
		count = param.numel()
		all_param += count
		if param.requires_grad:
			trainable_params += count
	# A model without parameters would otherwise raise ZeroDivisionError;
	# report 0% instead.
	trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
	print(
		f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
	)

@dataclass
class ModelArguments:
	"""Command-line arguments that select the base model checkpoint."""

	# Hugging Face Hub id or local path of the base causal-LM checkpoint.
	model_name_or_path: Optional[str] = "EleutherAI/pythia-1.4b-deduped"
	# NOTE(review): not read by the visible code, and the "llama" default does not
	# match the pythia default checkpoint above — confirm whether it is still needed.
	model_type: Optional[str] = "llama"

@dataclass
class TrainingArguments(transformers.TrainingArguments):
	"""
	Hugging Face ``TrainingArguments`` extended with this script's own options.

	Besides the inherited ``output_dir``, the visible evaluation code reads
	``cache_dir``, ``model_max_length``, ``file_name`` and ``rank``.
	"""

	# Cache directory for downloaded Hugging Face files (None = library default).
	cache_dir: Optional[str] = None
	optim: str = "adamw_torch"
	# Prefix of the per-rank output file, written as "<file_name>_<rank>.csv".
	file_name: str = "response_test.csv"
	# Index of the data shard this process evaluates.
	rank: int = 0
	# Tokenizer max length; also compared against the model context to decide RoPE scaling.
	# NOTE(review): the help text says "right padded", but the evaluation tokenizer pads on the left.
	model_max_length: int = field(
		default=32768,  # 8192 * 4
		metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
	)
	# NOTE(review): not read by the visible code; the model is loaded without an
	# attention-implementation argument — confirm whether this flag is still honored.
	use_flash_attn: bool = field(
		default=True,
		metadata={"help": "Whether use flash attention for training."},
	)

# System message placed before every user prompt in the chat template.
PREPEND: str = "Help achieve the objective by generating the next step."


def _seed_everything(seed):
	"""Seed torch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``."""
	torch.manual_seed(seed)
	np.random.seed(seed)
	random.seed(seed)
	torch.cuda.manual_seed_all(seed)


def _load_config_with_rope_scaling(model_args, training_args):
	"""
	Load the model config and enable YaRN RoPE scaling when needed.

	If ``training_args.model_max_length`` exceeds ``max_position_embeddings``
	times any existing RoPE factor, ``config.rope_scaling`` is replaced by a
	YaRN entry whose factor is the ceiling of the ratio of the two lengths.
	"""
	config = transformers.AutoConfig.from_pretrained(
		model_args.model_name_or_path,
		cache_dir=training_args.cache_dir,
	)

	orig_rope_scaling = getattr(config, "rope_scaling", None)
	if orig_rope_scaling is None:
		orig_rope_scaling = {"factor": 1}
	orig_rope_scaling_factor = orig_rope_scaling.get("factor", 1)

	orig_ctx_len = getattr(config, "max_position_embeddings", None)
	if orig_ctx_len:
		orig_ctx_len *= orig_rope_scaling_factor
		if training_args.model_max_length > orig_ctx_len:
			# NOTE(review): the new factor is computed against the already-scaled
			# length but *replaces* the original factor instead of multiplying it.
			# Kept unchanged, presumably to mirror the training-time config — confirm.
			scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
			config.rope_scaling = {"factor": scaling_factor, "type": "yarn"}
	return config


def _load_model_and_tokenizer(model_args, training_args, config):
	"""Load the bf16 causal LM and a left-padding tokenizer that pads with EOS."""
	model = transformers.AutoModelForCausalLM.from_pretrained(
		model_args.model_name_or_path,
		config=config,
		cache_dir=training_args.cache_dir,
		torch_dtype=torch.bfloat16,
	)
	tokenizer = transformers.AutoTokenizer.from_pretrained(
		model_args.model_name_or_path,
		cache_dir=training_args.cache_dir,
		model_max_length=training_args.model_max_length,
		padding_side="left",
	)
	# Reuse EOS as the padding token.
	tokenizer.pad_token = tokenizer.eos_token
	tokenizer.pad_token_id = tokenizer.eos_token_id
	return model, tokenizer


def _load_shard(idx):
	"""Read ``test_final.csv``, keep rows ``idx``, copy ``chunk`` into ``label_ids`` and rename ``target`` to ``label``."""
	df = pd.read_csv("test_final.csv")
	df = df.iloc[idx].reset_index(drop=True)
	df["prompt"] = df["chunk"]
	df.rename(columns={"prompt": "label_ids", "target": "label"}, inplace=True)
	return df


def _generate_response(model, tokenizer, prompt):
	"""Sample one completion for ``prompt`` (with ``PREPEND`` as system message) and return only the new text."""
	messages = [
		{"role": "system", "content": PREPEND},
		{"role": "user", "content": prompt},
	]
	input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
	model_inputs = tokenizer(input_text, return_tensors="pt")
	inputs = {key: value.to("cuda").reshape(1, -1) for key, value in model_inputs.items()}
	input_len = inputs["input_ids"].shape[1]

	generated_ids = model.generate(
		**inputs,
		max_new_tokens=500,
		do_sample=True,
		temperature=0.6,
		top_p=0.95,
		pad_token_id=tokenizer.eos_token_id,
	)
	# Drop the echoed prompt tokens and keep only the newly generated ones.
	return tokenizer.decode(generated_ids[0][input_len:], skip_special_tokens=True)


def eval(rank, idx, model_args, training_args):
	"""
	Generate a sampled response for each row of one shard of ``test_final.csv``.

	Args:
		rank: shard index; used only to name the output file.
		idx: row positions of ``test_final.csv`` handled by this call.
		model_args: provides ``model_name_or_path`` of the base model.
		training_args: ``output_dir`` must hold the adapter and a ``model.pt``
			state dict; ``file_name`` is the output prefix; ``cache_dir`` and
			``model_max_length`` configure loading.

	Writes ``f"{training_args.file_name}_{rank}.csv"`` containing the shard's
	rows plus a ``response`` column.
	"""
	torch.cuda.empty_cache()
	_seed_everything(0)

	config = _load_config_with_rope_scaling(model_args, training_args)
	model, tokenizer = _load_model_and_tokenizer(model_args, training_args, config)
	df = _load_shard(idx)

	model.load_adapter(training_args.output_dir)
	# map_location="cpu": the model is still on CPU here; without it, tensors
	# saved from a GPU are first materialised on that same GPU (for every rank).
	state_dict = torch.load(f"{training_args.output_dir}/model.pt", map_location="cpu")
	# strict=False: model.pt presumably holds only a subset of weights — TODO confirm.
	model.load_state_dict(state_dict, strict=False)
	del state_dict
	# Switch to eval mode only after all weights/modules are loaded. Newly built
	# nn.Modules start in training mode, so modules injected by load_adapter
	# (e.g. LoRA dropout) would be missed by an eval() call made earlier.
	model.eval()
	# NOTE(review): "cuda" is the current device; all ranks share it unless the
	# launcher sets CUDA_VISIBLE_DEVICES per process — confirm.
	model.to("cuda")

	df["response"] = [
		_generate_response(model, tokenizer, prompt)
		for prompt in tqdm(df["label_ids"].to_list())
	]
	df.to_csv(f"{training_args.file_name}_{rank}.csv", index=False)


if __name__ == "__main__":
	# Force "spawn" for any torch.multiprocessing use: CUDA cannot be
	# re-initialised in forked child processes.
	set_start_method('spawn', force=True)
	parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
	model_args, training_args = parser.parse_args_into_dataclasses()

	# The test set is split into num_proc contiguous shards of row positions and
	# this process handles shard `rank` (presumably one process per rank).
	num_proc = 8
	rank = training_args.rank
	if not 0 <= rank < num_proc:
		# A negative rank would otherwise silently index from the end and
		# duplicate another shard's work; fail fast with a clear message.
		raise ValueError(f"rank must be in [0, {num_proc}), got {rank}")

	num_rows = len(pd.read_csv("test_final.csv"))
	# np.arange keeps an integer dtype even when the CSV is empty.
	shards = np.array_split(np.arange(num_rows), num_proc)
	eval(rank, shards[rank], model_args, training_args)
